import os
import time
import copy
import zipfile
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from google.colab import drive
import pandas as pd
import seaborn as sns
import cv2
import random
import matplotlib.image as mpimg
from sklearn.metrics import confusion_matrix, classification_report
# Pick the compute device: first CUDA GPU when one is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
Using device: cuda:0
# Number of CPU cores on this runtime (informational; DataLoader num_workers
# is hard-coded to 2 further below).
cpu_count = os.cpu_count()
cpu_count
2
# Colab-only: mount Google Drive and point at the zipped dataset archive.
drive.mount('/content/drive')
zip_path = "/content/drive/My Drive/GarbageDataSets/GarbageData.zip"
Mounted at /content/drive
# Unpack the dataset archive into local Colab storage for fast reads.
with zipfile.ZipFile(zip_path, mode='r') as archive:
    archive.extractall("/content/data")
# Root of the extracted dataset and the 12 target categories.
# NOTE: the list order defines the integer label mapping used later
# (class_idx in create_dataset_splits), so it must not be reordered.
base_dir = "/content/data/garbage_classification/"
classes = [
    "white-glass",
    "trash",
    "shoes",
    "plastic",
    "paper",
    "metal",
    "green-glass",
    "clothes",
    "cardboard",
    "brown-glass",
    "biological",
    "battery",
]
# Empty index of (path, filename, target) rows; filled by the scan below.
data = pd.DataFrame(columns=['path', 'filename', 'target'])
# Square side length (pixels) of the model input.
IMG_SIZE = 224
# Walk every category folder and record each file it contains as one row.
frames = []
for category in classes:
    class_path = os.path.join(base_dir, category)
    # All entries in the class directory (no extension filtering here).
    photos = os.listdir(class_path)
    frames.append(pd.DataFrame({
        'path': [os.path.join(class_path, photo) for photo in photos],
        'filename': [f"{category}/{photo}" for photo in photos],
        'target': category,
    }))
# Stitch the per-class frames onto the (empty) main DataFrame.
data = pd.concat([data] + frames, ignore_index=True)
# Verify
print(f"Total samples: {len(data)}")
print(data.head())
Total samples: 15515
path \
0 /content/data/garbage_classification/white-gla...
1 /content/data/garbage_classification/white-gla...
2 /content/data/garbage_classification/white-gla...
3 /content/data/garbage_classification/white-gla...
4 /content/data/garbage_classification/white-gla...
filename target
0 white-glass/white-glass605.jpg white-glass
1 white-glass/white-glass440.jpg white-glass
2 white-glass/white-glass459.jpg white-glass
3 white-glass/white-glass50.jpg white-glass
4 white-glass/white-glass331.jpg white-glass
# Echo the assembled DataFrame in the notebook.
data
| path | filename | target | |
|---|---|---|---|
| 0 | /content/data/garbage_classification/white-gla... | white-glass/white-glass605.jpg | white-glass |
| 1 | /content/data/garbage_classification/white-gla... | white-glass/white-glass440.jpg | white-glass |
| 2 | /content/data/garbage_classification/white-gla... | white-glass/white-glass459.jpg | white-glass |
| 3 | /content/data/garbage_classification/white-gla... | white-glass/white-glass50.jpg | white-glass |
| 4 | /content/data/garbage_classification/white-gla... | white-glass/white-glass331.jpg | white-glass |
| ... | ... | ... | ... |
| 15510 | /content/data/garbage_classification/battery/b... | battery/battery583.jpg | battery |
| 15511 | /content/data/garbage_classification/battery/b... | battery/battery711.jpg | battery |
| 15512 | /content/data/garbage_classification/battery/b... | battery/battery601.jpg | battery |
| 15513 | /content/data/garbage_classification/battery/b... | battery/battery114.jpg | battery |
| 15514 | /content/data/garbage_classification/battery/b... | battery/battery848.jpg | battery |
15515 rows × 3 columns
# List the distinct labels to confirm all 12 categories were collected.
data['target'].unique()
array(['white-glass', 'trash', 'shoes', 'plastic', 'paper', 'metal',
'green-glass', 'clothes', 'cardboard', 'brown-glass', 'biological',
'battery'], dtype=object)
# Show 10 randomly chosen images along with their file paths.
for _ in range(10):
    idx = random.randint(0, len(data) - 1)
    sample = data.iloc[idx]
    plt.imshow(mpimg.imread(sample['path']))
    print(sample['path'])
    plt.show()
/content/data/garbage_classification/clothes/clothes1837.jpg
/content/data/garbage_classification/paper/paper291.jpg
/content/data/garbage_classification/battery/battery620.jpg
/content/data/garbage_classification/clothes/clothes2250.jpg
/content/data/garbage_classification/biological/biological321.jpg
/content/data/garbage_classification/shoes/shoes1741.jpg
/content/data/garbage_classification/white-glass/white-glass624.jpg
/content/data/garbage_classification/clothes/clothes3251.jpg
/content/data/garbage_classification/shoes/shoes1140.jpg
/content/data/garbage_classification/clothes/clothes418.jpg
# Count photos per category straight from the directory tree (cross-check
# against the DataFrame totals).
total_counts = 0
for category in os.listdir(base_dir):
    category_full_path = os.path.join(base_dir, category)
    # Ignore stray non-directory entries at the dataset root.
    if not os.path.isdir(category_full_path):
        continue
    count_class = len(os.listdir(category_full_path))
    total_counts += count_class
    print(f"{category} has {count_class} photos")
print(f"\nTotal photos: {total_counts}")
shoes has 1977 photos white-glass has 775 photos green-glass has 629 photos battery has 945 photos cardboard has 891 photos biological has 985 photos clothes has 5325 photos paper has 1050 photos brown-glass has 607 photos plastic has 865 photos trash has 697 photos metal has 769 photos Total photos: 15515
# Per-class sample counts (note the heavy 'clothes' class imbalance).
data['target'].value_counts()
| count | |
|---|---|
| target | |
| clothes | 5325 |
| shoes | 1977 |
| paper | 1050 |
| biological | 985 |
| battery | 945 |
| cardboard | 891 |
| plastic | 865 |
| white-glass | 775 |
| metal | 769 |
| trash | 697 |
| green-glass | 629 |
| brown-glass | 607 |
# Echo the DataFrame again.
data
| path | filename | target | |
|---|---|---|---|
| 0 | /content/data/garbage_classification/white-gla... | white-glass/white-glass605.jpg | white-glass |
| 1 | /content/data/garbage_classification/white-gla... | white-glass/white-glass440.jpg | white-glass |
| 2 | /content/data/garbage_classification/white-gla... | white-glass/white-glass459.jpg | white-glass |
| 3 | /content/data/garbage_classification/white-gla... | white-glass/white-glass50.jpg | white-glass |
| 4 | /content/data/garbage_classification/white-gla... | white-glass/white-glass331.jpg | white-glass |
| ... | ... | ... | ... |
| 15510 | /content/data/garbage_classification/battery/b... | battery/battery583.jpg | battery |
| 15511 | /content/data/garbage_classification/battery/b... | battery/battery711.jpg | battery |
| 15512 | /content/data/garbage_classification/battery/b... | battery/battery601.jpg | battery |
| 15513 | /content/data/garbage_classification/battery/b... | battery/battery114.jpg | battery |
| 15514 | /content/data/garbage_classification/battery/b... | battery/battery848.jpg | battery |
15515 rows × 3 columns
# Preview the first 24 images in a 3x8 grid, titled with their class name.
# BUG FIX: cv2.imread returns BGR channel order but plt.imshow expects RGB,
# so the original preview showed red/blue swapped; convert before display.
imgPaths = data['path']
fig, axs = plt.subplots(3, 8, figsize=(25, 10))
axs = axs.flatten()
for ax, imgPath in zip(axs, imgPaths):
    label = str(imgPath).split('/')[-2]  # class name is the parent directory
    img = cv2.imread(imgPath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    ax.imshow(img)
    ax.set_title(label)
    ax.axis('off')
plt.tight_layout()
plt.show()
# Class-distribution bar chart.
# FIX: seaborn >= 0.13 deprecates passing `palette` without `hue` (the
# FutureWarning this cell emitted); assign the x variable to hue and drop
# the redundant legend, which produces the same plot warning-free.
plt.figure(figsize=(10, 5))
sns.countplot(x="target", hue="target", data=data, palette='Blues', legend=False)
plt.xticks(rotation=90)
plt.title('Categories')
plt.show()
<ipython-input-19-3a4b8330395b>:3: FutureWarning: Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect. sns.countplot(x="target", data=data, palette='Blues')
def create_dataset_splits():
    """Collect image paths per class and build a stratified 80/20 train/val split.

    Returns:
        dict with keys 'train' and 'val', each a tuple (image_paths, labels)
        where labels are integer indices into the module-level `classes` list.
    """
    # Dictionary to store image paths for each class
    class_images = {cls: [] for cls in classes}
    # FIX: compare extensions case-insensitively so files like "IMG.JPG" are
    # not silently dropped (the DataFrame scan earlier keeps all files).
    valid_exts = ('.jpg', '.jpeg', '.png')
    for class_name in classes:
        class_dir = os.path.join(base_dir, class_name)
        if os.path.exists(class_dir):
            for img_name in os.listdir(class_dir):
                if img_name.lower().endswith(valid_exts):
                    class_images[class_name].append(os.path.join(class_dir, img_name))
        else:
            print(f"Warning: Directory not found for class {class_name}")
    # Split into train and validation sets (80/20 split)
    train_images = []
    val_images = []
    train_labels = []
    val_labels = []
    # Splitting each class independently keeps the 80/20 ratio per class
    # (stratification), so class balance is identical across both splits.
    for class_idx, class_name in enumerate(classes):
        images = class_images[class_name]
        print(f"Class {class_name}: {len(images)} images")
        if len(images) > 0:
            train_imgs, val_imgs = train_test_split(
                images, test_size=0.2, random_state=42
            )
            train_images.extend(train_imgs)
            val_images.extend(val_imgs)
            train_labels.extend([class_idx] * len(train_imgs))
            val_labels.extend([class_idx] * len(val_imgs))
    return {
        'train': (train_images, train_labels),
        'val': (val_images, val_labels)
    }
# Create the dataset splits
# Builds stratified 80/20 train/val path+label lists for all 12 classes.
dataset_splits = create_dataset_splits()
Class white-glass: 775 images Class trash: 697 images Class shoes: 1977 images Class plastic: 865 images Class paper: 1050 images Class metal: 769 images Class green-glass: 629 images Class clothes: 5325 images Class cardboard: 891 images Class brown-glass: 607 images Class biological: 985 images Class battery: 945 images
class GarbageDataset(Dataset):
    """Path-list dataset: loads one RGB image and returns an (image, label) pair."""

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Decode as RGB so grayscale/palette files still yield three channels.
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]
# Preprocessing pipelines: random augmentation for training, deterministic
# resize + center-crop for validation. Normalization uses the ImageNet
# statistics to match the pretrained ResNet backbone.
_imagenet_norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        _imagenet_norm,
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _imagenet_norm,
    ]),
}
# Wrap the split path/label lists in Datasets and batched DataLoaders.
image_datasets = {}
dataloaders = {}
for split in ('train', 'val'):
    paths, labels = dataset_splits[split]
    image_datasets[split] = GarbageDataset(paths, labels, data_transforms[split])
    dataloaders[split] = DataLoader(
        image_datasets[split],
        batch_size=32,
        shuffle=True,
        num_workers=2,
    )
# Get dataset sizes
dataset_sizes = {split: len(image_datasets[split]) for split in ('train', 'val')}
num_classes = len(classes)
print(f"Total images: {dataset_sizes['train'] + dataset_sizes['val']}")
print(f"Training images: {dataset_sizes['train']}")
print(f"Validation images: {dataset_sizes['val']}")
print(f"Number of classes: {num_classes}")
Total images: 15515 Training images: 12409 Validation images: 3106 Number of classes: 12
# ResNet-50 pretrained on ImageNet (V2 weights); downloads on first use.
model_ft = models.resnet50(weights='IMAGENET1K_V2')
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth 100%|██████████| 97.8M/97.8M [00:00<00:00, 217MB/s]
# Replace the 1000-way ImageNet head with a 12-way classifier.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, num_classes) # num_classes = 12
# Move to the target device BEFORE building the optimizer so Adam's state
# tracks the on-device parameter tensors.
model_ft = model_ft.to(device)
# CrossEntropyLoss expects raw logits plus integer class labels.
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.001, weight_decay=1e-4)
# Learning rate scheduler to reduce LR by factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
def train_model(model, criterion, optimizer, scheduler, num_epochs=5):
    """Fine-tune `model`, checkpointing the best validation-accuracy weights.

    Runs one train and one val phase per epoch, steps `scheduler` once per
    epoch (after the train phase), overwrites a single Drive checkpoint
    whenever validation accuracy improves, plots loss/accuracy curves, and
    returns the model with the best weights re-loaded.

    Args:
        model: network already moved to `device`.
        criterion: loss function taking raw logits and integer labels.
        optimizer: optimizer built over `model.parameters()`.
        scheduler: per-epoch learning-rate scheduler.
        num_epochs: number of full train+val passes.

    Returns:
        The model with the best validation weights loaded.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    # Per-epoch metric history for the plots below.
    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # enable dropout / batch-norm updates
            else:
                model.eval()   # freeze batch-norm stats, disable dropout
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # Zero the parameter gradients
                optimizer.zero_grad()
                # Track gradient history only in the train phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # Weight the mean batch loss by batch size so the final
                # (possibly partial) batch is accounted for correctly.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            # Exactly one scheduler step per epoch, after the training phase.
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Save history for plotting
            if phase == 'train':
                train_loss_history.append(epoch_loss)
                train_acc_history.append(epoch_acc.item())
            else:
                val_loss_history.append(epoch_loss)
                val_acc_history.append(epoch_acc.item())
            # Deep copy the model if best accuracy achieved
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                # Save checkpoint to Google Drive (same filename every time
                # so only the best model is kept).
                checkpoint_path = '/content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth'
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_acc': best_acc,
                    'class_names': classes
                }, checkpoint_path)
                print(f"Best model updated at epoch {epoch} with accuracy {best_acc:.4f}")
                print(f"Checkpoint saved to Drive: {checkpoint_path}")
        print()
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    # BUG FIX: original format spec was '{best_acc:4f}' (minimum width 4,
    # default 6-decimal precision); '.4f' (four decimal places) was intended.
    print(f'Best val Acc: {best_acc:.4f}')
    # Load best model weights
    model.load_state_dict(best_model_wts)
    # Plot training curves
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_loss_history, label='Train Loss')
    plt.plot(val_loss_history, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(train_acc_history, label='Train Accuracy')
    plt.plot(val_acc_history, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.tight_layout()
    plt.savefig('/content/training_curves.png')
    # Also save to Drive
    plt.savefig('/content/drive/My Drive/GarbageDataSets/garbage_training_curves.png')
    plt.show()
    return model
# Train for 25 epochs (the StepLR above drops the LR 10x every 7 epochs).
num_epochs = 25
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=num_epochs)
# Save the final model to Google Drive
final_model_path = '/content/drive/My Drive/GarbageDataSets/garbage_resnet50_final_model.pth'
torch.save({
'model_state_dict': model_ft.state_dict(),
'class_names': classes,
'num_classes': num_classes
}, final_model_path)
print(f"Final model saved to Drive: {final_model_path}")
Epoch 0/24 ---------- train Loss: 0.9384 Acc: 0.7120 val Loss: 0.4874 Acc: 0.8609 Best model updated at epoch 0 with accuracy 0.8609 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 1/24 ---------- train Loss: 0.6394 Acc: 0.7941 val Loss: 0.4370 Acc: 0.8654 Best model updated at epoch 1 with accuracy 0.8654 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 2/24 ---------- train Loss: 0.6072 Acc: 0.8041 val Loss: 0.4346 Acc: 0.8722 Best model updated at epoch 2 with accuracy 0.8722 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 3/24 ---------- train Loss: 0.5806 Acc: 0.8142 val Loss: 0.5161 Acc: 0.8558 Epoch 4/24 ---------- train Loss: 0.5434 Acc: 0.8273 val Loss: 0.4079 Acc: 0.8793 Best model updated at epoch 4 with accuracy 0.8793 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 5/24 ---------- train Loss: 0.5478 Acc: 0.8247 val Loss: 0.3737 Acc: 0.8847 Best model updated at epoch 5 with accuracy 0.8847 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 6/24 ---------- train Loss: 0.5343 Acc: 0.8281 val Loss: 0.3489 Acc: 0.8941 Best model updated at epoch 6 with accuracy 0.8941 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 7/24 ---------- train Loss: 0.3579 Acc: 0.8845 val Loss: 0.2070 Acc: 0.9420 Best model updated at epoch 7 with accuracy 0.9420 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 8/24 ---------- train Loss: 0.2935 Acc: 0.9026 val Loss: 0.1989 Acc: 0.9430 Best model updated at epoch 8 with accuracy 0.9430 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 9/24 ---------- train Loss: 0.2712 Acc: 0.9075 val Loss: 0.1890 Acc: 0.9469 Best model updated at epoch 
9 with accuracy 0.9469 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 10/24 ---------- train Loss: 0.2582 Acc: 0.9153 val Loss: 0.1797 Acc: 0.9488 Best model updated at epoch 10 with accuracy 0.9488 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 11/24 ---------- train Loss: 0.2401 Acc: 0.9197 val Loss: 0.1790 Acc: 0.9475 Epoch 12/24 ---------- train Loss: 0.2330 Acc: 0.9223 val Loss: 0.1815 Acc: 0.9475 Epoch 13/24 ---------- train Loss: 0.2299 Acc: 0.9242 val Loss: 0.1729 Acc: 0.9533 Best model updated at epoch 13 with accuracy 0.9533 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 14/24 ---------- train Loss: 0.2029 Acc: 0.9332 val Loss: 0.1641 Acc: 0.9556 Best model updated at epoch 14 with accuracy 0.9556 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 15/24 ---------- train Loss: 0.2014 Acc: 0.9321 val Loss: 0.1594 Acc: 0.9581 Best model updated at epoch 15 with accuracy 0.9581 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 16/24 ---------- train Loss: 0.1889 Acc: 0.9370 val Loss: 0.1592 Acc: 0.9565 Epoch 17/24 ---------- train Loss: 0.1887 Acc: 0.9382 val Loss: 0.1580 Acc: 0.9569 Epoch 18/24 ---------- train Loss: 0.1746 Acc: 0.9425 val Loss: 0.1541 Acc: 0.9581 Epoch 19/24 ---------- train Loss: 0.1716 Acc: 0.9430 val Loss: 0.1535 Acc: 0.9591 Best model updated at epoch 19 with accuracy 0.9591 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 20/24 ---------- train Loss: 0.1710 Acc: 0.9427 val Loss: 0.1554 Acc: 0.9578 Epoch 21/24 ---------- train Loss: 0.1733 Acc: 0.9416 val Loss: 0.1540 Acc: 0.9598 Best model updated at epoch 21 with accuracy 0.9598 Checkpoint saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_best.pth Epoch 22/24 
---------- train Loss: 0.1722 Acc: 0.9450 val Loss: 0.1557 Acc: 0.9598 Epoch 23/24 ---------- train Loss: 0.1671 Acc: 0.9457 val Loss: 0.1551 Acc: 0.9594 Epoch 24/24 ---------- train Loss: 0.1734 Acc: 0.9435 val Loss: 0.1550 Acc: 0.9578 Training complete in 61m 15s Best val Acc: 0.959755
Final model saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_final_model.pth
# NOTE(review): this cell duplicates the final-model save above — it
# re-writes the same weights to the same Drive path (harmless but redundant).
final_model_path = '/content/drive/My Drive/GarbageDataSets/garbage_resnet50_final_model.pth'
torch.save({
'model_state_dict': model_ft.state_dict(),
'class_names': classes,
'num_classes': num_classes
}, final_model_path)
print(f"Final model saved to Drive: {final_model_path}")
Final model saved to Drive: /content/drive/My Drive/GarbageDataSets/garbage_resnet50_final_model.pth
# Function to visualize model predictions
def visualize_model(model, num_images=6):
    """Plot `num_images` validation samples titled predicted vs. actual class.

    Restores the model's original train/eval mode before returning; the figure
    is saved locally and to Drive once `num_images` samples have been drawn.
    """
    was_training = model.training
    model.eval()
    shown = 0
    plt.figure(figsize=(15, 8))
    with torch.no_grad():
        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            preds = torch.max(model(inputs), 1)[1]
            for j in range(inputs.size(0)):
                shown += 1
                ax = plt.subplot(num_images // 3 + 1, 3, shown)
                ax.axis('off')
                ax.set_title(f'Predicted: {classes[preds[j]]} | Actual: {classes[labels[j]]}')
                # Undo the ImageNet normalization so the image displays correctly.
                img = inputs.cpu().data[j].numpy().transpose((1, 2, 0))
                img = np.array([0.229, 0.224, 0.225]) * img + np.array([0.485, 0.456, 0.406])
                ax.imshow(np.clip(img, 0, 1))
                if shown == num_images:
                    model.train(mode=was_training)
                    plt.tight_layout()
                    plt.savefig('/content/sample_predictions.png')
                    plt.savefig('/content/drive/My Drive/GarbageDataSets/garbage_sample_predictions.png')
                    return
    model.train(mode=was_training)
# Show a few validation predictions from the fine-tuned model.
visualize_model(model_ft)
def evaluate_model(model):
    """Run the model over the validation set and report per-class metrics.

    Saves a confusion-matrix heatmap (locally and to Drive), writes the
    classification report to a text file, prints per-class accuracy, and
    returns a {class_name: accuracy} dict.
    """
    model.eval()
    preds_all = []
    labels_all = []
    with torch.no_grad():
        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_preds = torch.max(model(inputs), 1)[1]
            preds_all.extend(batch_preds.cpu().numpy())
            labels_all.extend(labels.cpu().numpy())
    # Confusion-matrix heatmap.
    cm = confusion_matrix(labels_all, preds_all)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('/content/confusion_matrix.png')
    plt.savefig('/content/drive/My Drive/GarbageDataSets/garbage_confusion_matrix.png')
    # Precision/recall/F1 per class.
    report = classification_report(labels_all, preds_all, target_names=classes)
    print("Classification Report:")
    print(report)
    # Persist the report alongside the other artifacts on Drive.
    with open('/content/drive/My Drive/GarbageDataSets/garbage_classification_report.txt', 'w') as f:
        f.write(report)
    # Per-class accuracy = diagonal / row sum of the confusion matrix.
    class_accuracies = {}
    for i, class_name in enumerate(classes):
        class_accuracy = cm[i, i] / np.sum(cm[i, :])
        class_accuracies[class_name] = class_accuracy
        print(f"Accuracy for class {class_name}: {class_accuracy:.4f}")
    return class_accuracies
# Evaluate on the validation split and collect per-class accuracy.
class_accuracies = evaluate_model(model_ft)
Classification Report:
precision recall f1-score support
white-glass 0.97 0.91 0.94 155
trash 0.96 0.97 0.97 140
shoes 0.95 0.96 0.95 396
plastic 0.91 0.90 0.90 173
paper 0.97 0.96 0.96 210
metal 0.85 0.92 0.88 154
green-glass 0.96 0.97 0.96 126
clothes 0.99 0.99 0.99 1065
cardboard 0.97 0.93 0.95 179
brown-glass 0.95 0.93 0.94 122
biological 0.98 0.94 0.96 197
battery 0.94 0.95 0.94 189
accuracy 0.96 3106
macro avg 0.95 0.94 0.95 3106
weighted avg 0.96 0.96 0.96 3106
Accuracy for class white-glass: 0.9097
Accuracy for class trash: 0.9714
Accuracy for class shoes: 0.9646
Accuracy for class plastic: 0.8960
Accuracy for class paper: 0.9571
Accuracy for class metal: 0.9156
Accuracy for class green-glass: 0.9683
Accuracy for class clothes: 0.9925
Accuracy for class cardboard: 0.9330
Accuracy for class brown-glass: 0.9344
Accuracy for class biological: 0.9391
Accuracy for class battery: 0.9524
# Bar chart of per-class validation accuracy.
names = list(class_accuracies.keys())
scores = list(class_accuracies.values())
plt.figure(figsize=(12, 6))
plt.bar(names, scores)
plt.xlabel('Class')
plt.ylabel('Accuracy')
plt.title('Per-Class Accuracy')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('/content/per_class_accuracy.png')
plt.savefig('/content/drive/My Drive/GarbageDataSets/garbage_per_class_accuracy.png')
plt.show()
def predict_image(model, image_path):
    """Classify a single image file and plot it with its predicted class.

    Applies the deterministic validation transform, prints the top-3 classes
    with confidences, and returns (predicted_class, confidence_percent).
    """
    model.eval()
    # Load and preprocess the image with the validation pipeline.
    img = Image.open(image_path).convert('RGB')
    transform = data_transforms['val']
    img_tensor = transform(img).unsqueeze(0).to(device)
    # Make prediction
    with torch.no_grad():
        outputs = model(img_tensor)
        _, preds = torch.max(outputs, 1)
        predicted_class = classes[preds[0]]
        # Softmax over the logits -> per-class probabilities.
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
        confidence = probs[preds[0]].item() * 100
    # Display the image with prediction.
    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.axis('off')
    # BUG FIX: the original title contained a literal '\\n', so the plot
    # rendered a backslash-n instead of breaking the line.
    plt.title(f'Predicted: {predicted_class}\nConfidence: {confidence:.2f}%')
    plt.show()
    # Return top 3 predictions with confidence
    top_probs, top_classes = torch.topk(probs, 3)
    results = [(classes[i], p.item() * 100) for i, p in zip(top_classes, top_probs)]
    print("Top 3 predictions:")
    for cls, conf in results:
        print(f"{cls}: {conf:.2f}%")
    return predicted_class, confidence
# Sanity-check on a single hand-picked image from Drive.
predicted_class, confidence = predict_image(
model=model_ft,
image_path="/content/drive/My Drive/hello.jpg"
)
Top 3 predictions: white-glass: 99.46% green-glass: 0.29% plastic: 0.11%